import numpy as np
import emdp.emdp.gridworld as gw
from emdp.emdp import actions



def _four_room_wall_pairs(size):
    """Return ``(i, k)`` index pairs describing the four-room walls.

    ``i`` runs over the middle row of a row-major ``size x size`` grid,
    skipping the two door cells at columns ``size // 4`` and
    ``3 * size // 4``.  ``k`` is the transposed index of ``i`` (row and
    column swapped), so it runs over the middle column with doors at the
    matching rows.
    """
    row_start = (size // 2) * size
    doors = (row_start + size // 4, row_start + 3 * size // 4)
    return [
        (i, size * (i % size) + i // size)
        for i in range(row_start, row_start + size)
        if i not in doors
    ]


def build_MDP(size=19, goal_state=82, gamma=0.99):
    """Build a four-room gridworld MDP.

    The grid is created with ``gw.build_simple_grid(size=size, p_success=1)``
    and then split into four rooms by a horizontal and a vertical wall
    through the middle, each with two door cells left open.

    Args:
        size: Side length of the square grid.
        goal_state: Flat state index that yields a reward of 1 for every
            action; all other rewards are 0.
        gamma: Discount factor handed to ``gw.GridWorldMDP``.

    Returns:
        The object returned by
        ``gw.GridWorldMDP(P, R, gamma, p0, terminal_states, size)``, with an
        empty list of terminal states.
    """
    P = gw.build_simple_grid(size=size, p_success=1)
    R = np.zeros((P.shape[0], P.shape[1]))  # reward matrix of size |S| x |A|

    # Start distribution: uniform over every cell that is not part of a wall
    # (wall entries are zeroed here, normalisation happens further down).
    p0 = np.ones(P.shape[0])
    walls = _four_room_wall_pairs(size)
    for i, k in walls:
        p0[i] = 0
        p0[k] = 0

    # ------ Make 4 rooms
    # Any transition that lands on a wall cell with probability exactly 1 is
    # replaced by a self-loop on the source state j.  Only wall cells are
    # visited, in the same (j, wall, action) order as a full scan would.
    for j in range(size ** 2):
        for i, k in walls:
            for a in range(4):
                if P[j, a, i] == 1:
                    P[j, a, i] = 0
                    P[j, a, j] = 1
                elif P[j, a, k] == 1:
                    P[j, a, k] = 0
                    P[j, a, j] = 1

    p0 = p0 / p0.sum()
    R[goal_state, :] = 1

    terminal_states = []
    return gw.GridWorldMDP(P, R, gamma, p0, terminal_states, size)


class Wrapper:
    """Episode wrapper that returns MDP states as ``(1, N, N)`` arrays.

    A fresh MDP is built on every ``reset``.  An episode is reported as done
    as soon as a truthy reward is received or more than ``max_steps`` steps
    have been taken; the ``done`` flag returned by the wrapped MDP is ignored.

    Args:
        build_mdp: Zero-argument callable returning an object with
            ``reset()`` and ``step(action)`` methods.  ``step`` must return a
            4-tuple ``(state, reward, done, info)``.
        grid_size: Side length ``N`` used to reshape each state; the state
            must support ``reshape((1, N, N))``.
        max_steps: An episode ends once the step count exceeds this value.
    """

    def __init__(self, build_mdp, grid_size=19, max_steps=500):
        self.build_mdp = build_mdp
        self.grid_size = grid_size
        self.max_steps = max_steps

    def _as_grid(self, state):
        """Reshape a state to ``(1, grid_size, grid_size)``."""
        return state.reshape((1, self.grid_size, self.grid_size))

    def reset(self):
        """Build a new MDP, zero the step counter, return the first state."""
        self.mdp = self.build_mdp()
        self.eps_step = 0
        return self._as_grid(self.mdp.reset())

    def step(self, action):
        """Advance the MDP by one action.

        Returns:
            ``(state, reward, done, info)`` where ``state`` is reshaped to
            ``(1, grid_size, grid_size)``, ``done`` is a plain ``bool`` and
            ``info`` is always an empty dict.
        """
        self.eps_step += 1
        state, reward, _, _ = self.mdp.step(action)
        # Episode ends on any truthy reward or when the step budget is spent.
        done = bool(reward) or self.eps_step > self.max_steps
        return self._as_grid(state), reward, done, {}
